#! /usr/bin/env python
## This file is part of biopy.
## Copyright (C) 2010 Joseph Heled
## Author: Joseph Heled <jheled@gmail.com>
## See the files gpl.txt and lgpl.txt for copying conditions.

from __future__ import division

import optparse, sys, os.path

import pylab

from biopy.genericutils import fileFromName

from biopy import INexus, beastLogHelper, treePlotting
from biopy.treeutils import getTreeClades, countNexusTrees, treeHeight

# Command-line interface.  Option values are stored as given (strings from the
# command line, or the defaults declared here) and are converted with
# int()/float() where they are used.
parser = optparse.OptionParser("""%prog [OPTIONS] posterior-trees.nexus fig-file-name""")

parser.add_option("-b", "--burnin", dest="burnin",
                  help="Burn-in amount (percent, default %default)", default = "10")

parser.add_option("-e", "--every", dest="every",
                  help="thin out - take one tree for every 'e'. Especially "
                  "useful if you run out of memory (default all, i.e. %default)",
                  default = "1")

# May be given several times; every occurrence adds one tree.
parser.add_option("-m", "--mean-tree", dest="mainTree", action="append",
                  help="Show tree on top of river (the tree from "
                  "starbeast_posterior_popsizes for instance)", default = None)

# May be given several times; colors are cycled over the mean trees.
parser.add_option("", "--mcolors", dest="mcolors", action="append",
                  help="""colors for mean tree(s).""", default = None)

parser.add_option("-s", "--single", dest="singleTopology",
                  help="Restrict to a single topology (either the most frequent "
                  "topology, or the one from the mean-tree when given)",
                  action="store_true", default = False)

parser.add_option("-c", "--color", dest="color",
                  help="""Assign tree color based on its topology.""",
                  action="store_true", default = False)

parser.add_option("", "--yclip", dest="yclip", metavar="LEVEL",
                  help="Clip display of root demographic using upper HPD "
                  "(0 < level < 1)", default = 1)

parser.add_option("", "--ychop", dest="ychop", metavar="LEVEL",
                  help="chop root demographic to fraction of total tree "
                  "(0 < level)", default = None)

parser.add_option("", "--xaxislim", dest="xlim", metavar="min,max",
                  help="Limits of the X axis, given as min,max.",
                  default = None)

parser.add_option("", "--alphafactor", dest="afactor", metavar="F",
                  help="Reduce this value to lower drawing intensity "
                  "(default %default).", default = 10.0)

parser.add_option("", "--tree-spacing", dest="treespacing",
                  help="""Additional spacing (> 0 default %default)""", default = 0.3)

parser.add_option("", "--positioning", dest="positioning",
                  type='choice', choices=['mean', 'taxonmean', 'between', 'star'],
                  help="Method of positioning internal nodes based on clade "
                  "taxa (mean [default] | taxonmean | between | star).",
                  default = 'mean')

parser.add_option("", "--fontsize", dest="fontsize", metavar="N",
                  help="""labels font size.""", default = 12)

parser.add_option("", "--figsize", dest="figsize", metavar="W,H",
                  help="""figure size. Width,Height in Inches.""", default = None)

parser.add_option("", "--wfactor", dest="wfac", metavar="F",
                  help="""Increase or decrease the spacing between taxa""", default = "1")

# Only used by the fallback that runs when label extents cannot be measured.
parser.add_option("", "--labelsspacehack", dest="labsphack", metavar="F",
                  help="Space reserved below the plot for the labels when their "
                  "extent cannot be measured: fraction of the plot height per "
                  "label character (default %default).", default = ".01")

parser.add_option("", "--dpi", dest="dpi", metavar="N",
                  help="""image DPI.""", default = "300")

parser.add_option("", "--labelsoffset", dest="labeloff", metavar="offset",
                  help="""labels plot offset.""", default = 0.0)

parser.add_option("-i", "--interactive", dest="interactive", action="store_true",
                  help="""present plot interactively to user.""", default = False)

parser.add_option("-p", "--progress", dest="progress",
                  help="Print out progress messages to terminal (standard error)",
                  action="store_true", default = False)

options, args = parser.parse_args()

# Decode the options that are used throughout the rest of the script.
progress = options.progress
every = int(options.every)
# Burn-in is given as a percentage; keep it as a fraction.
burnIn =  float(options.burnin)/100.0
colorTops = options.color
alphaFactor = float(options.afactor)

# Name of the treePlotting attribute implementing each --positioning choice.
# 'star' maps to no strategy at all (positioning is None).
positioningNames = {'mean'      : 'DescendantMean',
                    'taxonmean' : 'DescendantWeightedMean',
                    'between'   : 'DescendantBetween',
                    'star'      : None}

if options.positioning not in positioningNames :
  # optparse restricts the value to the declared choices, so getting here
  # means the table above and the option declaration are out of sync.
  raise ValueError("unknown positioning method: %s" % options.positioning)

positioningName = positioningNames[options.positioning]
# Looked up lazily so that only the selected attribute is touched.
positioning = getattr(treePlotting, positioningName) if positioningName else None

# Exactly two positional arguments: the trees file and the figure file name.
if len(args) != 2 :
  parser.print_help(sys.stderr)
  sys.exit(1)

# Optional trees to draw on top of the posterior trees (-m).  Every one of
# them must come with population size data.
mainTree = []
if options.mainTree :
  mainTree = [INexus.Tree(t) for t in options.mainTree]
  has = beastLogHelper.setDemographics(mainTree)
  if not all(has):
    print >> sys.stderr, "Main tree(s) missing population size data"
    sys.exit(1)

nexusTreesFileName = args[0]

try :
  nexFile = fileFromName(nexusTreesFileName)
except Exception as e:
  # str(e) and not e.message: the latter is deprecated and is empty for
  # multi-argument exceptions such as IOError(errno, strerror).
  print >> sys.stderr, "Error:", str(e)
  sys.exit(1)

if progress:
  print >> sys.stderr, "counting trees ...,",

# The total is needed up front to turn the burn-in fraction into a count.
nTrees = countNexusTrees(nexusTreesFileName)

nexusReader = INexus.INexus()

nBurninTrees = int(burnIn*nTrees)

if progress:
  print >> sys.stderr, "reading %d trees ...," % ((nTrees - nBurninTrees)//every),

trees = []
# NOTE(review): the slice stops at -1; whether that leaves out the last tree
# depends on how INexus.read interprets the slice - confirm.
for tree in nexusReader.read(nexFile, slice(nBurninTrees, -1, every)) :
  has = beastLogHelper.setDemographics([tree])
  # Explicit check instead of an assert, which is stripped under -O.
  if not has[0] :
    print >> sys.stderr, "Posterior tree missing population size data"
    sys.exit(1)
  trees.append(tree)

# The drawing intensity is later computed per tree, which needs at least one.
if not trees :
  print >> sys.stderr, "No trees left to plot after burn-in and thinning"
  sys.exit(1)
  
if progress:
  print >> sys.stderr, "preparing ...,",

# Horizontal distance between consecutive taxa: the spacing computed by
# treePlotting scaled by the user's --wfactor.
wd = treePlotting.getSpacing(trees, float(options.treespacing)) * float(options.wfac)

# The first mean tree, when given, serves as the reference tree.
ref = mainTree[0] if mainTree else None
taxOrder = treePlotting.getTaxaOrder(trees, refTree = ref,
                                     reportTopologies = colorTops)
oo,refTree = taxOrder[:2]
# X coordinate of each taxon: its rank in the ordering times the spacing.
xs = dict((t, k*wd) for k,t in enumerate(oo))

# With --single, cd maps each clade of the reference tree (as a frozenset) to
# the pair of clades of its two descendants, or to None for a node without
# descendants.  It stays None otherwise.
cd = None
if options.singleTopology:
  # node id -> (clade, node)
  d1 = dict((n.id, (c, n)) for c,n in getTreeClades(refTree, True))
  cd = {}
  for c,n in d1.values() :
    if n.succ :
      cd[frozenset(c)] = (frozenset(d1[n.succ[0]][0]),
                          frozenset(d1[n.succ[1]][0]))
    else :
      cd[frozenset(c)] = None

if progress:
  print >> sys.stderr, "plotting ...,",

# Build the figure with interactive mode off; --figsize, when given, is
# passed on as [width, height].
pylab.ioff()
figureArgs = {}
if options.figsize is not None :
  figureArgs['figsize'] = [float(x) for x in options.figsize.split(',')]
fig = pylab.figure(**figureArgs)

# Per posterior tree: the value returned by drawTree, and the tree height.
heights = []
pheights = []

# Drawing intensity of one tree: inversely proportional to the number of
# trees, kept within [1/500, 0.3].
# matplotlib has only 8 bit support for alpha
alphaFloor, alphaCeiling = 1/500., .3
alpha = alphaFactor/len(trees)
if alpha < alphaFloor :
  alpha = alphaFloor
if alpha > alphaCeiling :
  alpha = alphaCeiling

def _drawPosteriorTree(tree, color) :
  """Draw one posterior tree in `color` and record its heights.

  The tips of `tree` are first placed at the X position of their taxon (from
  the module-level `xs`).  The value returned by treePlotting.drawTree is
  appended to `pheights` and the tree height to `heights`.
  """
  for x in tree.get_terminals() :
    tree.node(x).data.x = xs[tree.node(x).data.taxon]
  h = treePlotting.drawTree(tree, tree.root, cd, positioning = positioning,
                            generalPlotAttributes = {'color' : color, 'alpha' : alpha})
  pheights.append(h)
  heights.append(treeHeight(tree))

# NOTE: the loops below deliberately use the module-level name `tree`; the
# labelling code further down reads the tree it is left bound to.
if colorTops :
  import colorsys
  # taxOrder[2] is used as a mapping whose values are sequences of indices
  # into `trees` (presumably topology -> trees having it; verify against
  # getTaxaOrder).  Process the least frequent first.
  tops = sorted(taxOrder[2].items(), key = lambda x : len(x[1]))
  # main is a lime-like sort of green
  mainHue = 100./256
  # Each group gets a hue interval as wide as its share of the trees and is
  # drawn in the hue at the middle of that interval.  (The divisions are true
  # divisions: the file imports division from __future__.)
  c0 = (mainHue + len(tops[-1][1])/len(trees)/2) % 1
  for top in tops:
    topPercent = len(top[1])/len(trees)
    c2 = (c0 + topPercent/2) % 1
    # full saturation, but why v=184??
    col = colorsys.hsv_to_rgb(c2,1.0,184/256.)
    for k in top[1]:
      tree = trees[k]
      _drawPosteriorTree(tree, col)
    c0 = (c0 + topPercent) % 1
else :
  for tree in trees:
    _drawPosteriorTree(tree, "lime")

from biopy.bayesianStats import hpd

# --yclip: clipping is active only for a level strictly between 0 and 1, in
# which case hmin is the upper bound of the HPD of the drawn heights.
yc = float(options.yclip)
yclip = yc > 0 and yc < 1
hmin = hpd(pheights, yc)[1] if yclip else 0

# --ychop: cut the plot at the mean tree height plus the given fraction of it
# (negative levels are treated as 0).
if options.ychop is not None :
  from numpy import mean
  level = float(options.ychop)
  if level < 0 :
    level = 0
  chop = mean(heights) * (1+level)

# Draw the mean tree(s) on top, unfilled and fully opaque, cycling through
# the user's colors (or a default palette).
if mainTree :
  import itertools

  colors = options.mcolors if options.mcolors is not None \
           else ["red", "blue", "green"]

  for tree,col in zip(mainTree, itertools.cycle(colors)):
    # Place every tip at the X position of its taxon.
    for x in tree.get_terminals() :
      tip = tree.node(x)
      tip.data.x = xs[tip.data.taxon]
    h = treePlotting.drawTree(tree, tree.root, None, fill=None,
                              positioning = positioning,
                              generalPlotAttributes = {'color' : col})
    # The clip level never falls below a mean tree.
    hmin = max(h, hmin)

# One vertical text label per taxon, centred on the taxon's X position and
# hanging down from y = -labeloff.  `tree` is whichever tree the drawing
# loops above left bound; it is only used to enumerate the taxa.
labelY = -float(options.labeloff)
labelSize = float(options.fontsize)
labels = []
for ni in tree.get_terminals():
  taxon = tree.node(ni).data.taxon
  labels.append(pylab.text(xs[taxon], labelY, taxon, fontsize = labelSize,
                           va='top', ha='center', rotation="vertical"))

# Upper Y limit: --ychop wins over --yclip, which wins over matplotlib's own.
if options.ychop is not None:
  ymax = chop
elif yclip :
  ymax = hmin
else :
  ymax = pylab.ylim()[1]


# Apply the upper Y limit chosen above; the lower limit is left as it is here
# and adjusted for the labels afterwards.
pylab.ylim((pylab.ylim()[0], ymax))

def adjustForlabels(labels) :
  """Lower the bottom Y limit of the current axes so that `labels` fit.

  `labels` is a sequence of matplotlib text objects.  The window extent of
  each label is converted to data coordinates and the lower Y limit is moved
  down to 1.3 times the lowest Y found, if that is below the current limit.
  The upper limit is left unchanged.

  NOTE(review): the caller treats a RuntimeError from here as "no renderer
  available", presumably raised by get_window_extent - confirm.
  """
  # gca() returns the axes the trees were drawn on; a bare axes() call may
  # create a new axes instead.
  ax = pylab.gca()
  toData = ax.transData.inverted()
  ymin,ymax = pylab.ylim()
  for t in labels:
    ex = t.get_window_extent()
    lowest = min(toData.transform(corner)[1] for corner in (ex.p0, ex.p1))
    # 1.3 leaves some margin below the text.
    ymin = min(ymin, 1.3*lowest)
  pylab.ylim((ymin, ymax))
  
# Make room for the labels.  Measuring them needs a drawn figure; when that
# fails, reserve space proportional to the longest label text instead.
try :
  pylab.draw()
  adjustForlabels(labels)
except RuntimeError:
  # no renderer??
  # A total hack for an annoying bug (text ends up outside the frame)
  nTxt = max(len(t.get_text()) for t in labels)
  pylab.ylim((-ymax*nTxt*float(options.labsphack), ymax))

# Optional user-specified X range, given as "min,max".
if options.xlim :
  xRange = [float(x) for x in options.xlim.split(',')]
  pylab.xlim(xRange)

# Show the plot first when asked to; the figure is saved in any case.
if options.interactive:
  pylab.ion()
  pylab.show()

pylab.savefig(args[1], dpi=float(options.dpi))

if progress:
  print >>  sys.stderr, "done."
